{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Simple Variational GP Classification with Pyro"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import torch\n",
    "import gpytorch\n",
    "import pyro\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 50 evenly spaced training inputs on [0, 1]\n",
    "train_x = torch.linspace(start=0, end=1, steps=50)\n",
    "# Labels: sign of cos(4*pi*x), shifted and scaled from {-1, +1} to {0, 1}\n",
    "train_y = (torch.sign(torch.cos(train_x * (4 * math.pi))) + 1) / 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from gpytorch.variational import CholeskyVariationalDistribution, VariationalStrategy\n",
    "from gpytorch.models import PyroVariationalGP\n",
    "\n",
    "\n",
    "class PyroGPClassificationModel(PyroVariationalGP):\n",
    "    \"\"\"Variational GP classification model built on PyroVariationalGP.\n",
    "\n",
    "    Uses a zero mean, a scaled Matern (nu=0.5) kernel, and a Cholesky\n",
    "    variational distribution sized by the number of inducing points.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, likelihood, inducing_points, num_data=None):\n",
    "        \"\"\"Set up the variational strategy and the mean/kernel modules.\n",
    "\n",
    "        Args:\n",
    "            likelihood: likelihood object forwarded to PyroVariationalGP.\n",
    "            inducing_points: tensor of inducing locations; its first\n",
    "                dimension sets the size of the variational distribution.\n",
    "            num_data: number of training examples forwarded to\n",
    "                PyroVariationalGP. When omitted, falls back to the size of\n",
    "                the notebook-level `train_y` (the original behaviour).\n",
    "        \"\"\"\n",
    "        if num_data is None:\n",
    "            # Backward-compatible default; pass num_data explicitly to avoid\n",
    "            # depending on the notebook-global `train_y`.\n",
    "            num_data = train_y.numel()\n",
    "        variational_distribution = CholeskyVariationalDistribution(inducing_points.size(0))\n",
    "        variational_strategy = VariationalStrategy(self, inducing_points, variational_distribution)\n",
    "        super().__init__(variational_strategy, likelihood, num_data=num_data, name_prefix=\"basic_gp_test\")\n",
    "        self.mean_module = gpytorch.means.ZeroMean()\n",
    "        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.MaternKernel(nu=0.5))\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return a MultivariateNormal built from the mean and kernel evaluated at `x`.\"\"\"\n",
    "        mean_x = self.mean_module(x)\n",
    "        covar_x = self.covar_module(x)\n",
    "        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)\n",
    "\n",
    "    \n",
    "# Build the Bernoulli likelihood, then the GP model with the training inputs as inducing points\n",
    "likelihood = gpytorch.likelihoods.BernoulliLikelihood()\n",
    "model = PyroGPClassificationModel(likelihood=likelihood, inducing_points=train_x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyro import optim\n",
    "from pyro import infer\n",
    "\n",
    "# Adam optimizer wrapped by Pyro, with a learning rate of 0.01\n",
    "optimizer = optim.Adam({\"lr\": 0.01})\n",
    "# ELBO loss estimated from 64 particles evaluated in a single vectorized pass\n",
    "elbo = infer.Trace_ELBO(vectorize_particles=True, num_particles=64)\n",
    "# Stochastic variational inference over the GP's Pyro model/guide pair\n",
    "svi = infer.SVI(model=model.model, guide=model.guide, optim=optimizer, loss=elbo)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iteration 10, Loss = 69.04036617279053\n",
      "Iteration 20, Loss = 62.06441354751587\n",
      "Iteration 30, Loss = 67.1903223991394\n",
      "Iteration 40, Loss = 60.96347999572754\n",
      "Iteration 50, Loss = 58.23017120361328\n",
      "Iteration 60, Loss = 54.84669291973114\n",
      "Iteration 70, Loss = 59.56512427330017\n",
      "Iteration 80, Loss = 57.56812238693237\n",
      "Iteration 90, Loss = 59.487281799316406\n",
      "Iteration 100, Loss = 58.22996139526367\n"
     ]
    }
   ],
   "source": [
    "num_epochs = 100\n",
    "\n",
    "# Seed Pyro's RNG so the sampled ELBO estimates are repeatable across runs\n",
    "pyro.set_rng_seed(0)\n",
    "# Re-running this cell after the prediction cell would otherwise train in eval mode\n",
    "model.train()\n",
    "\n",
    "for i in range(num_epochs):\n",
    "    # One SVI gradient step on the full training set; returns the loss estimate\n",
    "    loss = svi.step(train_x, train_y)\n",
    "    # Report the loss every 10th iteration\n",
    "    if (i + 1) % 10 == 0:\n",
    "        print('Iteration {}, Loss = {}'.format(i + 1, loss))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Switch both the model and the likelihood to evaluation mode before predicting\n",
    "model.eval()\n",
    "likelihood.eval()\n",
    "test_x = torch.linspace(0, 1, 200)\n",
    "# Gradients are not needed for prediction, so skip building the autograd graph\n",
    "with torch.no_grad():\n",
    "    pred_dist = model(test_x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pass the model's output at test_x through the Bernoulli likelihood;\n",
    "# the plotting cell reads `pred_y.mean` as the positive-class probability\n",
    "pred_y = likelihood(pred_dist)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.legend.Legend at 0x7f7c540734e0>"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQgAAADDCAYAAAB+ro88AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAFYRJREFUeJzt3W1sG8eZB/A/9WLLjWPRciw7QNCAtFEjOKMRxSnatGgSSxQucIH2rMhxENj54poFkg+Hg5FIB/tO50OcyAnUAkFySFwpQNBCaVJVvgZXI0AstRcDaYpbm/KXOghssnEPtaT4hZJ9AfW694G7FCXvkEsuX3Y2/x9AUMuluMOd5cOZ4ew+Pl3XQURkpabaBSAi92KAICIpBggikmKAICIpBggikqpz+gJCiKjx5zZN07ot1ncBSAJo1TTtZafbI6LKcdSCEEJEAJzRNO0kgKCxnL2+FQA0TTsDIGkuE5EanHYxggDMoBA3lrPtQ7r1YK6PgIiU4aiLYbQcTK0A3l31FD+AG1nLm5xsj4gqy/EYBJDpSpzXNO18sa/R09PDKZ1EVdLX1+ezerwkAQJAxGqAEunuRZPxtx/A9VwvcuzYsbwbmpqaQnNzc8EFrBS3lw9gGUvB7eUD7Jext7dXus7xz5xCiKj564Q5SCmE8Bur38XyuEQQwBmn2yOiyinFrxgnhBCXhRA3s1aNAoDZ5TCel3TSBSGiynM6SHkGwEaLx8NZf59cvZ4ol4WFBdy6dQu3bt2CW882XlpawszMTLWLkdPqMvp8PqxduxZbt25FXZ29j36pxiCISmZiYgKNjY3YtGkTfD7LsbOqm5+fR319fbWLkdPqMuq6jmQyiYmJCdx33322XoNTrcl1ZmdnsWHDhqoHh1gshlgsVvbtJJNJjIyMlH07Pp8Pfr8fs7Oztv+HAYJcR9d128Hh6tWriEQimJiYKHp7sVgMAwMDGB0dxcDAAOLxOACgsbERw8PDRb+uXX6/33I7sVgMDzzwAEZGRjAyMoL+/v5M2azkWmfy+XwFddvYxSClvfTSS/j444/x4osv4tVXXy34/5PJJF555RUMDQ1lHnvqqacwNDSEpqamHP9ZWhs33jGUh1AohEAggM7Ozsxju3fvxunTp+94bjwex+DgII4fP17ScjFAkJL8fj9SqVRm+eTJkzh58iQaGhqQTCZz/OdKw8PDaGtrW/HYxo0bMTo6inA4jFgshtHRUYyPj+PgwYM4d+4cAODcuXPo6urC2NgYmpqaEAgEkEgkMDw8jEAggB07duCDDz7A0NAQnn32WRw+fBgAVjw/EAhgcHAQLS0tOH/e3g98fr8/01IYGxsDALS1tWF8fByJRAKxWAyNjY0YGxvD4uIiOjo6EAyuPgPCPnYxSEkXL17Evn37sG7dOgDAunXr8OSTT+LTTz8t+LWmp6el60KhENrb29HS0oLBwUGMj49jbGwMu3btwtGjRxEOhzPBoa2tDRs3bsTx48fx9NNPZ16js7MTwWDwjucfOXIEe/bsQXt7OwKBQEFlDgaDaGpqQlNTE06dOoW2tjYEAgGEQqE71jnBAEFKuvfee7FhwwbMzs6ioaEhM7C5devWgl6nra0t0yowJRIJtLe3r3jM7G7s2bMHBw8eRH9/P+bm5tDY2IhQKJRphfj9/sz/tLW1ob+/H+Fw5lf/O55fqGQyiWAwiP7+fjQ2NqKlpSXzOJDuapjrHnzwwRXrisEuBilramoKhw4dwsGDBzE4OFjUQGUwGMRzzz2HgYEBBAIBjI+P47XXXsusTyaTK7oYZpdg165d6OjowODgYObb22ziJ5NJ+P1+dHV14ciRI5mg8cILL6x4/uHDh3Hq1Cm0tLRk/jcUCmW2HYvFkEgkMr9wJBKJTNnM7U1PTyMej+PmzZtIJpNIJBKZdTdu3EA8HkcikVjxuoXwuWUiSk9Pj85zMSrD7WW8dOkS
7r//flfPM1BxHoTp0qVL2L59e2a5t7dXerIWuxhEJMUAQURSDBBEJMUAQURSDBBEJMUAQURSDBD0lRaLxfDQQw+tOGszHo/f8dhXFSdKkWs1NKwt2WulUtanOIdCocxEqddffx1Aeuq1OW35q64kLYhcCXGEECeM+6jsOUTV1NjYKF0Xj8cxMDCAkZERxGKxzPJbb72FeDyO0dFR7N69G6Ojozhy5EgFS10ZpbhobQTAr3M8JSqEuIx04hwi21Kp2ZLd8uns7MTAwMAd051Xn2C1+kSo9vZ2+P1+tLe3Ozrnwa0cBwjjupS5PvyHNE3bZjyPyJXa29szp0+vln2CldWJUFbXcvCKSgxSBoUQESHE8xXYFlFBYrEYBgcHEY/HMy0F81JzsVgsc4LV6Ogobty4kWlJXLhwAfF4HKdPn0YikcicFOW1gc2yD1Jm5czoEEJE2JIgNwmFQpmrSZkXdQmFQrh48WLmOdlXaTIvvjI/P4+9e/cCSF+BCoDllZ5UV9YAYQxM3tA0bRjprFo5L20zNTWV9zXd3s9ze/kA95dxaWkJi4uL1S5GTm4vHyAv49LSkq3PGlCmACGE8GualgSgYXl8YhuAN3P9n91TkN18qjLg/vIB7i7jzMwMamtrXX86tdvLB1iXsaamxnb9l+JXjK70nejKejg7s9YTxrrLzKxFdhR65WWyr5ArhgMlaEEY3YfhVY8xsxYVbe3atZiZmXF14hwVmYlz1q61PwGNMynJdbZu3YrPP/8c09PTrm1JLC0toabG3WcqrC5jduo9uxggyHXq6upw9913u3qcxO2X7QNKU0Z3h0AiqioGCCKSYoAgIikGCCKSYoAgIikGCCKSYoAgIikGCCKSYoAgIikGCCKSYoAgIikGCCKSYoAgIikGCCKSYoAgIikGCCKSYoAgIqlK5ObsYuIcIjWVNTenGTiMZDnJXIHEjqtXr2Lv3r2YmJjILEcikaKX7T7HC5zuq0rsa9ZFafdt9melaLquO76Fw+EPJY+fCIfDEePvSDgcfl72Gt3d3Xoqlcp5i0ajus/n06PRaGa5pqam6GW7zynkduXKlaL+r9y37Pd15coVx/uuHPs6XxnddHNSz5Xat9mflVy37u5uXfa59JXiqsFCiA81TeuwePxNAG9qmnbeaGl0aJrWbfUaPT09+rFjxyxf3+/3I5UKAPjIcVntmwTw92houF5QJio3Xcz07bdr8JOfXIc3rk38V6xduxvT03+pdkEAFFfP69cfxcLCc6jc0N9HAB4HADQ0NEiP497eXvT19VnmF3DVkSNLB3b27Fk8//wv8Pvf31PB0twDIf4Jb7zxqO00ZYC70tq98849AOxf4tzd7kF//38XVBflVEw9f+Mb/4Y///muMpRGZgMaGhrw2GOP4ejRo0Xtu3IHiCSAJuNvP9L5OaVkEbm5uRnB4Cz+8Idm1NfXY35+Htu3b8elS5ewZs0azM3NFbx84MB+6Drwy1/+8o7n+HxvYGmpE+vXb8TOnTsLftNuaUH4fOm0a+3tb2Bs7F8z79PMUF3svivlvpYtZ5dxdvZtAO24774AmpuXqrtTsxRaz2YaPJ8vgjVrPi37vq2vB+bm5tDc3FzUcQyUqa0jhPAbf76L5YS9QQBFZ/a+dm0C0Wgn3n//LUSjnbh9+y+IRjtx9uypopanp+OYmYlbPmf37u8DAJLJ/3O0H6ptfj59n0r9r6N9Vc59bWf561/fvOL9qMos/+OPP1yRffv++7/AoUOHMDk5WXSZHY9BGHk3fw7gkJGGD0KIc2b6PSPDdxxAMFcavlxjENkq0cd/5pk6vPVWLV57bR4//nFh31huGoN49NF6fPJJDcbG5vDd7y7Xs5vKKJNdxn376vDb39binXfmsWePO1oQxezDb36zHp99VoMLF+awY0f5M4bZLWNZxyC8mJvTTIi8sFDdcjhlll+BJNQ5meVXvwWR/gzW1bkznaAVzqS0YFagWaGqMj9Qda4aii5c
bW36Xv0Akb5XqT4YICx45xsrfc8WhDuo2KJjgLBgRnjVD0gGCHdRsT4YICx4ZwxCvT6vFbM+FherWw6n2ILwCK+1IFTq81qpr+eYULUwQFjwTgsifa/SN5YVrwVsleqDAcIC+7zu4oX60PXlFpBK9cEAYcELByTgnQDhhRaEOX5SU6OjRqFPnUJFrRyzz2sO8qnKK10ML3T5VK0LBggLnJjjLl4IEKrWBQOEBS90MVTt81rxQn2o2t1jgLDghQNS1T6vFS+MQTBAeIiXmrSqHZBWvDAPQtX6YICw4IUAoeqgmBWzBeGF+uAYhAd4qUmr2gFpxQtdvuVp71UuSIEYICwsH5Bs0rqBFwLEcn2odV4MA4SF5etBVLkgDngpQHihi6FqfTBAWPDSN5ZqTVorXqoPBggP8MLpxYuL5hwItZq0VjhoXD2lSL2XM/emEOKEcR91uq1K4SClu3hpTMicpasKRwHCZu7NqBDiMtJXtlYCm7TuwjGh6nH6/bIPwIfG33EAEQDnVz0nczl8VTBAuIsXWnSqdjGcBgg/gBtZy5ssnhM08nK2apr2cq4Xs5MarBKp7WZm6gDci9nZxYLTlbkl9d4XX6wBsAXA/B3vwS1lzCW7jLdvp99LKnXne6mWQvfhtWsNADZD12cxNXWtPIVapRT1XPYeqhkUhBAdQoiI0R2xZDcRSbmTvty+nb5fWqoraltuSEqzfn26v75uXb1ledxQxnzMMjY3m2MP1u+lWgopy1131Rj3ayr6Hpxuy+kgZc7cm0KIqJF5C8a6IBTghSatNwcpq1sOJ1StD6cBwjL3ZlZuTg3L+Ti3Gcuux5/V3IUBonocBQhN084DgDHGkDSXAYxmrX/CaEVczlrvat44IL03D0Lt+kjfqxawS5Gb847cm8zNWX1eakHU1qp/CUBV64MzKS3wG8tdvFAfDBAe4qVBStVm7lnxQovuKzkG4VXLB6QPuqJdeLYg3EXVMSEGCAs+X3a/t8qFKRIDhLuoWh8MEBKqN2vNM1FVOyCteOF6ELzknMeo/q2l6jeWFdXrAlC3PhggJFQfqDT7vOaZkCrz0pgQWxAeofq3lqrfWFa8MCbEnzk9xisBQrVvLBmv1AcDhEeoPjCm6jeWjOqDxuYsUP7M6RGqZ/hW9RtLRv0xofS9avXBACGh+gGp6s9qMl7pYqg2s5UBQsIrB6Rq31gyrI/qYICQUP+AVLPPK6N6i44BwmPUHxRL36t2QMqYgc7M96EaVWe2MkBIeOUbyytjEF6pDwYIj/DKAemVAKF+ly99r1p9MEBIqN7FUPUbS0b9AKHmmFAlUu/lXO9WZkVOTt5EJBLBxMQEAODq1as5lycnJ3Out/MapVgeGztrvI8y7aAKM9/HM8/8o+39YOc5xS6b+TnsblPTxgGo14KArutF38LhcGs4HO4y/o6Gw+HWQtZn37q7u/VUKpX3duXKFVvPc3r74Q8XdEDXOzr+Q6+pqdGj0aieSqX0aDSac3n//v0519t5jVIsA/+pA7r+3ntzVduHTm6ry/i97y3qgK77fI/Y3g/l3Nf79+8vaJvARzqg62fOzFZtH8pu3d3duuxz6dMdnB5nJOb9UNO0M1bZs/Ktz9bT06MfO3Ys7zanpqYqknjka1/7LywtPQ7gdwA+K/v2Su8fAAQA/AANDWMrsixVah86kV1Gv9+PVOp3ANoADAP4azWLVqQnAdwL4DtoaLhQkexmduu5t7cXfX19lj8PlTv1np3UfBluSb0HAHv2fB+/+Q0A/MC4qenhh/8OP/3pv6/Yt6ql3jt79ix+9KMU/vY3AOiS/o8KIpFW9PW9XpEUgkqk3iuEW1LvAcDPfgZcvTqEP/5RQ21tHRYXF7B5czO++GIq5/LU1BTq6qzXf/vb3wEA/OlPnxT0msUvX8GOHY3YuXNnVfahU8up95rxyCP/gl/96j3U1jbY2g/l3tcLCwtobi50m58hGNxiWR/l3ofFKmvqPRvrXau5Gdi8
+T1Eo1/i44/3Ihr9ErW1r+ZdPnDgmnT9li1D2LJlqODXLH7Zj6mpyWrvypL48svPEI3O2d4P5d7XBw5cK2KbWzE5qVZ9OB2DaAUgNE07afxKcUbTtPNCCL+maUnZeqvXctsYRLHcXj6AZSwFt5cPKM0YRCVS71mtJyIFMPUeEUlxJiURSTFAEJEUAwQRSTFAEJEUAwQRSTFAEJEUAwQRSTFAEJEUAwQRSTFAEJEUAwQRSTFAEJEUAwQRSTFAEJEUAwQRSTFAEJEUAwQRSTFAEJEUAwQRSVUiN+cJ4z7qdFtEVFmOAoRxWXtomnYGgHmZ+9WiQojLAOJOtkVElef0qtb7AHxo/B0HEAGw+tL2hzRNG3a4HSKqgnLn5gSAYL7EvSY35eYsltvLB7CMpeD28gGK5OY0g4IQokMIETG6I5bclJvTCbeXD2AZS8Ht5QOclzFvgJAMLsbNcQfkyL1p/O8No4txHUDQUWmJqKLyBog8mbHeBSCMv4MAzgCAmZsTgIblwcltAN4svqhEVGmVyM35hBCiC8Bl5uYkUgtzcxKRFGdSEpEUAwQRSTFAEJEUAwQRSTFAEJEUAwQRSTFAEJEUAwQRSTFAEJEUAwQRSTFAEJEUAwQRSTFAEJEUAwQRSTFAEJEUAwQRSTFAEJFUSQKEJGGOuS5n5i0icq9SpN6LAPi1ZJ2dzFtE5FKOA4Tx4Zel1duH9KXxgeXMW0SkiHKPQdjJvEVELlX2zFqF6O3trXYRiCiL08xa+eTMvJWtr6/PZ+P1iKiCnGbWspSVWcsy8xYRqaEUv2J0pe9EV9bD2Zm1rDJvEZECfLquV7sMRORSnEnpcXYnqnEim9rKNVnRVb9iZDO6LEkArZqmvVzo+kqwUUZzgHebpmndFS0cVk5UE0IEhRCtVt08owvYAaDi+9HGPmxFevwKmqYNV7h4ZhnsHovBauSiNervTQDbLNbZOgZkXNmCyDcD0w0zNG2UMQLgjHHABI3lSnP1RDWb9fjPRmAIurSeW7H8q168GmUs52RFVwYI5H9Tbjjw85UhmPVY3FiutLwT1YxvlGr9upRzHxrfzP8DAJqmvVylQW47x9oJ4z7owoF4R5MV3Rog8r0pN8zQzFkGTdNOZjU3WwFolSpYgZryP6Vs8tXjtwBsEkK0VnGMJF89n0e65XBz1fM8wa0BwjOMJuf5Kn2z5JyoVuXWg13Xs34u78r35EoTQviR3s8vAfi5EKIaLcVcbE9WtOLWAJHvTTl60yVitwyRagxQGt7FctcmM1HNOKiBdL++yxhMbapC/znfPryO5b51EukWRaXlK2MUwEvG4OUhAK4IYll1bHkM2OXWAJHvwHb0pkskXxkhhIiao97VGKTMMVHNnMg2nPXLgN/iJcot3z4czlrvhzEeUWF569lk7Mvk6sfLrZyTFV07Ucr4Vosj66cjIcQ5TdPCsvVuKmPWdTJuIP0NtFeB5nzF2aznGwC+Va2WmI0yPm+sb6rWsVgurg0QRFR9bu1iEJELMEAQkRQDBBFJMUAQkRQDBBFJMUAQkRQDBBFJ/T+l4cFLRRqeaQAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 288x216 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Plot the training labels against the model's thresholded predictions\n",
    "f, ax = plt.subplots(1, 1, figsize=(4, 3))\n",
    "ax.plot(train_x.numpy(), train_y.numpy(), 'k*')\n",
    "# pred_y.mean holds the probabilities of belonging to the positive class;\n",
    "# threshold them at 0.5 to obtain 0/1 labels\n",
    "pred_labels = pred_y.mean.ge(0.5).float()\n",
    "ax.plot(test_x.numpy(), pred_labels.numpy(), 'b')\n",
    "ax.set_ylim([-1, 2])\n",
    "# One legend entry per plotted artist (no confidence band is drawn here);\n",
    "# the trailing semicolon suppresses the Legend repr in the cell output\n",
    "ax.legend(['Observed Data', 'Predicted Labels']);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
